{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "tags": [
     "nbsphinx-gallery"
    ]
   },
   "source": [
    "# Expression trees in PyBaMM\n",
    "\n",
    "The basic data structure that PyBaMM uses to express models is an expression tree. This data structure encodes a tree representation of a given equation. The expression tree is used to encode the equations of both the original symbolic model, and the discretised equations of that model. Once discretised, the model equations are then passed to the solver, which must then evaluate the discretised expression trees in order to perform the time-stepping.\n",
    "\n",
    "The expression tree must therefore satisfy three requirements:\n",
    "1. To encode the model equations, it must be able to encode an arbitrary equation, including unary and binary operators such as `*`, `-`, spatial gradients or divergence, symbolic parameters, scalar, matrices and vectors.\n",
    "2. To perform the time-stepping, it must be able to be evaluated, given the current state vector $\\mathbf{y}$ and the current time $t$\n",
    "3. For solvers that require it, its gradient with respect to a given variable must be able to be evaluated (once again given $\\mathbf{y}$ and $t$)\n",
    "\n",
    "As an initial example, the code below shows how to construct an expression tree of the equation $2y(1 - y) + t$. We use the `pybamm.StateVector` to represent $\\mathbf{y}$, which in this case will be a vector of size 1. The time variable $t$ is already provided by PyBaMM and is of class `pybamm.Time`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Note: you may need to restart the kernel to use updated packages.\n"
     ]
    }
   ],
   "source": [
    "%pip install \"pybamm[plot,cite]\" -q    # install PyBaMM if it is not installed\n",
    "import numpy as np\n",
    "\n",
    "import pybamm\n",
    "\n",
    "y = pybamm.StateVector(slice(0, 1))\n",
    "t = pybamm.t\n",
    "equation = 2 * y * (1 - y) + t\n",
    "equation.visualise(\"expression_tree1.png\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![expression_tree1](expression_tree1.png)\n",
    "\n",
    "Once the equation is constructed, we can evaluate it at a given $t=1$ and $\\mathbf{y}=\\begin{pmatrix} 2 \\end{pmatrix}$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[-3.]])"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "equation.evaluate(1, np.array([2]))"
   ]
  },
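  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick check by hand, substituting $t = 1$ and $y = 2$ into $2y(1 - y) + t$ gives\n",
    "\n",
    "$$2 \\cdot 2 \\cdot (1 - 2) + 1 = -4 + 1 = -3,$$\n",
    "\n",
    "which agrees with the value returned by `evaluate`."
   ]
  },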
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can also calculate the expression tree representing the gradient of the equation with respect to $t$,"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "diff_wrt_equation = equation.diff(t)\n",
    "diff_wrt_equation.visualise(\"expression_tree2.png\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![expression_tree2](expression_tree2.png)\n",
    "\n",
    "\n",
    "...and evaluate this expression,"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[-11.]])"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "diff_wrt_equation.evaluate(t=1, y=np.array([2]), y_dot=np.array([2]))"
   ]
  },
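  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This value can be checked by hand. Treating $y$ as a function of time, the product and chain rules give\n",
    "\n",
    "$$\\frac{d}{dt}\\left[2y(1 - y) + t\\right] = 2\\dot{y}(1 - y) - 2y\\dot{y} + 1 = 2\\dot{y}(1 - 2y) + 1.$$\n",
    "\n",
    "Substituting $y = 2$ and $\\dot{y} = 2$ (the `y_dot` argument) gives $2 \\cdot 2 \\cdot (1 - 4) + 1 = -11$, in agreement with the output above. This is a hand derivation from the equation itself, not a description of how PyBaMM computes the derivative internally."
   ]
  },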
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## The PyBaMM Pipeline\n",
    "\n",
    "Proposing, parameter setting and discretising a model in PyBaMM is a pipeline process, consisting of the following steps:\n",
    "\n",
    "1. The model is proposed, consisting of equations representing the right-hand-side of an ordinary differential equation (ODE), and/or algebraic equations for a differential algebraic equation (DAE), and also associated boundary condition equations\n",
    "2. The parameters present in the model are replaced by actual scalar values from a parameter file, using the [`pybamm.ParameterValues`](https://docs.pybamm.org/en/latest/source/api/parameters/parameter_values.html) class\n",
    "3. The equations in the model are discretised onto a mesh, any spatial gradients are replaced with linear algebra expressions and the variables of the model are replaced with state vector slices. This is done using the [`pybamm.Discretisation`](https://docs.pybamm.org/en/latest/source/api/spatial_methods/discretisation.html) class.\n",
    "\n",
    "## Stage 1 - Symbolic Expression Trees\n",
    "\n",
    "At each stage, the expression tree consists of certain types of nodes. In the first stage, the model is first proposed using [`pybamm.Parameter`](https://docs.pybamm.org/en/latest/source/api/expression_tree/parameter.html), [`pybamm.Variable`](https://docs.pybamm.org/en/latest/source/api/expression_tree/variable.html), and other [unary](https://docs.pybamm.org/en/latest/source/api/expression_tree/unary_operator.html) and [binary](https://docs.pybamm.org/en/latest/source/api/expression_tree/binary_operator.html) operators (which also includes spatial operators such as [`pybamm.Gradient`](https://docs.pybamm.org/en/latest/source/api/expression_tree/unary_operator.html#pybamm.Gradient) and [`pybamm.Divergence`](https://docs.pybamm.org/en/latest/source/api/expression_tree/unary_operator.html#pybamm.Divergence)). For example, the right hand side of the equation\n",
    "\n",
    "$$\\frac{d c}{dt} = D \\nabla \\cdot \\nabla c$$\n",
    "\n",
    "can be constructed as an expression tree like so:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "D = pybamm.Parameter(\"D\")\n",
    "c = pybamm.Variable(\"c\", domain=[\"negative electrode\"])\n",
    "\n",
    "dcdt = D * pybamm.div(pybamm.grad(c))\n",
    "dcdt.visualise(\"expression_tree3.png\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![expression_tree3](expression_tree3.png)\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Stage 2 - Setting parameters\n",
    "\n",
    "In the second stage, the `pybamm.ParameterValues` class is used to replace all the parameter nodes with scalar values, according to an input parameter file. For example, we'll use a this class to set $D = 2$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "parameter_values = pybamm.ParameterValues({\"D\": 2})\n",
    "dcdt = parameter_values.process_symbol(dcdt)\n",
    "dcdt.visualise(\"expression_tree4.png\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![expression_tree4](expression_tree4.png)\n",
    "\n",
    "## Stage 3 - Linear Algebra Expression Trees\n",
    "\n",
    "The third and final stage uses the `pybamm.Discretisation` class to discretise the spatial gradients and variables over a given mesh. After this stage the expression tree will encode a linear algebra expression that can be evaluated given the state vector $\\mathbf{y}$ and $t$.\n",
    "\n",
    "**Note:** for demonstration purposes, we use a dummy discretisation below. For a more complete description of the `pybamm.Discretisation` class, see the example notebook [here](https://github.com/pybamm-team/PyBaMM/blob/develop/docs/source/examples/notebooks/spatial_methods/finite-volumes.ipynb)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Here, we import a dummy discretisation from the PyBaMM tests directory.\n",
    "import sys\n",
    "\n",
    "sys.path.insert(0, pybamm.root_dir())\n",
    "from tests import get_discretisation_for_testing\n",
    "\n",
    "disc = get_discretisation_for_testing()\n",
    "disc.y_slices = {c: [slice(0, 40)]}\n",
    "dcdt = disc.process_symbol(dcdt)\n",
    "dcdt.visualise(\"expression_tree5.png\")"
   ]
  },
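  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Schematically, and only as a sketch, the discretised tree produced by the cell above has the form\n",
    "\n",
    "$$\\frac{d\\mathbf{c}}{dt} \\approx 2 \\, \\mathbf{M}_{\\text{div}} \\left( \\mathbf{M}_{\\text{grad}} \\, \\mathbf{y}_{0:40} \\right),$$\n",
    "\n",
    "where $\\mathbf{y}_{0:40}$ is the slice of the state vector assigned to $c$ through `disc.y_slices`, and $\\mathbf{M}_{\\text{grad}}$ and $\\mathbf{M}_{\\text{div}}$ are placeholder names (not PyBaMM identifiers) for the matrices that the spatial method substitutes for the gradient and divergence operators. The actual entries of these matrices, and any extra terms arising from boundary conditions, depend on the mesh and spatial method used."
   ]
  },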
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![expression_tree5](expression_tree5.png)\n",
    "\n",
    "After the third stage, our expression tree is now able to be evaluated by one of the solver classes. Note that we have used a single equation above to illustrate the different types of expression trees in PyBaMM, but any given models will consist of many RHS or algebraic equations, along with boundary conditions. See [here](https://github.com/pybamm-team/PyBaMM/tree/develop/docs/source/examples/notebooks/creating_models/) for more details of PyBaMM models."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## References\n",
    "\n",
    "The relevant papers for this notebook are:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[1] Charles R. Harris, K. Jarrod Millman, Stéfan J. van der Walt, Ralf Gommers, Pauli Virtanen, David Cournapeau, Eric Wieser, Julian Taylor, Sebastian Berg, Nathaniel J. Smith, and others. Array programming with NumPy. Nature, 585(7825):357–362, 2020. doi:10.1038/s41586-020-2649-2.\n",
      "[2] Valentin Sulzer, Scott G. Marquis, Robert Timms, Martin Robinson, and S. Jon Chapman. Python Battery Mathematical Modelling (PyBaMM). Journal of Open Research Software, 9(1):14, 2021. doi:10.5334/jors.309.\n",
      "\n"
     ]
    }
   ],
   "source": [
    "pybamm.print_citations()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.11"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": true
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
